
function dXdt = N1DE_refactored(t, X, p)
% N1DE_REFACTORED  Right-hand side of the RA-chain + INT network ODE.
%   dXdt = N1DE_refactored(t, X, p) returns the time derivative of state
%   vector X for parameter struct p, using the (t, X) signature of MATLAB's
%   ode* solvers. t is accepted for that signature but is not used.
%
%   State layout (see build_indices; dXdt has the same shape as X):
%     RA cells 1..N_RA,   6 states each: [V, n, h, e, Cai, sAMin]
%     chain gates s_{i->i+1}, i = 1..N_RA-1  (RA_i drives RA_{i+1})
%     INT cells 1..N_INT, 8 states each: [V, n, h, rT, Cai, rf, rs, hp]
%
%   Required fields of p are listed in validateParams. Optional fields and
%   the defaults applied here: Iapp_RA, Iapp_INT (zeros), bCa (1e-3),
%   kCa (1e-2), bCa_INT (1e-3), kCa_INT (1e-2).
%
%   Errors if p is missing/not a struct, a required field is absent,
%   N_RA/N_INT are not non-negative integer scalars, or numel(X) does not
%   match the layout implied by N_RA and N_INT.

% -------- Input hygiene -------------------------------------------------
if nargin < 3 || ~isstruct(p); error('Provide parameter struct p.'); end
validateParams(p);

% The index map is cached and rebuilt only when the network size changes,
% so repeated RHS evaluations with the same sizes skip rebuilding it.
persistent IDX last_sizes
sizes = [p.N_RA, p.N_INT];
if isempty(IDX) || isempty(last_sizes) || any(last_sizes ~= sizes)
    IDX = build_indices(p.N_RA, p.N_INT);
    last_sizes = sizes;
end
if numel(X) ~= IDX.n_states
    error('N1DE_refactored:stateSize', ...
          'State vector has %d elements; expected %d for N_RA=%d, N_INT=%d.', ...
          numel(X), IDX.n_states, p.N_RA, p.N_INT);
end

% -------- Unpack state --------------------------------------------------
% RA blocks (vectorized)
VRA   = X(IDX.RA.V);
nRA   = X(IDX.RA.n);
hRA   = X(IDX.RA.h);
eRA   = X(IDX.RA.e);
CaiRA = X(IDX.RA.Cai);
sAMin = X(IDX.RA.sAMin);   % incoming AMPA gates onto RA (external/local if used)

% RA->RA chain gates (from RA_i to RA_{i+1})
s_chain = X(IDX.RA_CHAIN);

% INT blocks (scaffold)
VINT   = X(IDX.INT.V);
nINT   = X(IDX.INT.n);
hINT   = X(IDX.INT.h);
rTINT  = X(IDX.INT.rT);   %#ok<NASGU> unpacked for completeness; no dynamics yet
CaiINT = X(IDX.INT.Cai);
rfINT  = X(IDX.INT.rf);   %#ok<NASGU>
rsINT  = X(IDX.INT.rs);   %#ok<NASGU>
hpINT  = X(IDX.INT.hp);

% -------- Defaults for applied currents ---------------------------------
if ~isfield(p,'Iapp_RA') || isempty(p.Iapp_RA),  p.Iapp_RA  = zeros(1,p.N_RA);  end
if ~isfield(p,'Iapp_INT')|| isempty(p.Iapp_INT), p.Iapp_INT = zeros(1,p.N_INT); end
Iapp_RA  = p.Iapp_RA(:);
Iapp_INT = p.Iapp_INT(:);

% --------- RA gating variables ------------------------------------------
% Boltzmann steady-state helper: 1/(1+exp((V-th)/sig)). sig < 0 yields a
% curve increasing with V (activation); sig > 0 a decreasing one.
sinf = @(V,th,sig) 1./(1 + exp((V - th)./sig));

s_s   = sinf(VRA, p.thetas, p.sigmas);
n_inf = sinf(VRA, p.thetan, p.sigman);
tau_n = p.taunbar ./ cosh((VRA - p.thetan) ./ (2*p.sigman));
m_inf = sinf(VRA, p.thetam, p.sigmam);

alpha_h = 0.128 .* exp(-(VRA + 50)./18);
beta_h  = 4 ./ (1 + exp(-(VRA + 27)./5));
h_inf   = alpha_h ./ (alpha_h + beta_h);

a_inf = sinf(VRA, p.thetaa, p.sigmaa);  % A-type K+ gate proxy
e_inf = sinf(VRA, p.thetae, p.sigmae);  % adaptation gate proxy

k_inf = (CaiRA.^2) ./ (CaiRA.^2 + p.ks.^2); % SK activation by Ca
Tgate = p.Tmax ./ (1 + exp(-(VRA - p.VT)./p.Kps)); % transmitter release gate

% --------- RA currents (vectorized) -------------------------------------
% Units: adopt the same conventions as legacy code (currents in nA).
% Note: g* can be vectors (heterogeneous cells) or scalars.
% The Ca current's V./(1-exp(2V/RTF)) factor is evaluated by ghk_factor,
% which removes the 0/0 (NaN) the direct formula produces at V == 0.
iCaL = p.gCa  .* (s_s.^2) .* p.Ca_ex .* ghk_factor(VRA, p.RTF);
iK   = p.gKRA .* (nRA.^4) .* (VRA - p.VK);
iNa  = p.gNaRA.* (m_inf.^3).* hRA .* (VRA - p.VNa);
iA   = p.gARA .*  a_inf   .* eRA .* (VRA - p.VK);
iSK  = p.gSKRA.*  k_inf   .* (VRA - p.VK);
iL   = p.gL   .* (VRA - p.VL);

% AMPA from (i -> i+1) chain onto postsynaptic RA(2:end); RA(1) has none.
iAM_chain = zeros(p.N_RA,1);
if p.N_RA > 1
    iAM_chain(2:end) = p.gAMrr .* s_chain .* (VRA(2:end) - p.VAM);
end

% Additional incoming AMPA (optional local/external) per cell
iAM_in = p.gAMrr .* sAMin .* (VRA - p.VAM);

% Net membrane current and voltage dynamics (sign convention: Cm dV/dt = -I_ion + Iapp)
I_ion_RA = iCaL + iK + iNa + iA + iSK + iL + iAM_chain + iAM_in;
dVRA     = (-I_ion_RA + Iapp_RA) ./ p.C_RA;

% --------- RA gating / Ca dynamics --------------------------------------
dnRA = (n_inf - nRA) ./ tau_n;
dhRA = (h_inf - hRA) ./ p.tauh;
deRA = (e_inf - eRA) ./ p.taue;

% Simple Ca2+ pool: d[Ca]/dt = -bCa * iCaL - kCa * Cai. With this iCaL form,
% iCaL < 0 for V < 0, so the first term raises Cai.
if ~isfield(p,'bCa'), p.bCa = 1e-3; end
if ~isfield(p,'kCa'), p.kCa = 1e-2; end
dCaiRA = -p.bCa .* iCaL - p.kCa .* CaiRA;

% AMPA gates: ds/dt = alpha * Tgate * (1 - s) - delta * s
ds_chain = zeros(max(p.N_RA-1,0),1);
if p.N_RA > 1
    T_pre    = Tgate(1:end-1); % transmitter from presynaptic RA(i)
    ds_chain = p.arAM .* T_pre .* (1 - s_chain) - p.adAM .* s_chain;
end

% Incoming AMPA gates per RA; as written they are driven by each cell's own
% Tgate. NOTE(review): confirm this self-drive is intended.
dsAMin = p.arAM .* Tgate .* (1 - sAMin) - p.adAM .* sAMin;

% --------- INT scaffold (illustrative; extend similarly) -----------------
% Minimal leak + Na/K + h, with placeholders for CaT/H kinetics; set g's to zero if unused.
m_inf_INT = sinf(VINT, p.thetam, p.sigmam);
iK_INT  = p.gKINT  .* (nINT.^4) .* (VINT - p.VK);
iNa_INT = p.gNaINT .* (m_inf_INT.^3).* hINT .* (VINT - p.VNa);
iH_INT  = p.gHINT  .* hpINT .* (VINT - p.VH);
iL_INT  = p.gL     .* (VINT - p.VL);

I_ion_INT = iK_INT + iNa_INT + iH_INT + iL_INT; % add CaT/SK/etc. as needed
dVINT     = (-I_ion_INT + Iapp_INT) ./ p.C_INT;

% Simple first-order gates as placeholders
tau_n_INT = p.taunbar ./ cosh((VINT - p.thetan)./(2*p.sigman));
n_inf_INT = sinf(VINT, p.thetan, p.sigman);
h_inf_INT = 1./(1+exp(-(VINT+45)/5));  % placeholder

dnINT = (n_inf_INT - nINT) ./ tau_n_INT;
dhINT = (h_inf_INT - hINT) ./ p.tauh;   % same time constant as the RA h gate

% Ca pool placeholders
if ~isfield(p,'bCa_INT'), p.bCa_INT = 1e-3; end
if ~isfield(p,'kCa_INT'), p.kCa_INT = 1e-2; end
dCaiINT = -p.bCa_INT .* 0 - p.kCa_INT .* CaiINT; % put your Ca currents instead of 0

% Other INT gates (rf, rs, hp) as placeholders (copy real kinetics from your legacy code)
drTINT = zeros(p.N_INT,1);
drfINT = zeros(p.N_INT,1);
drsINT = zeros(p.N_INT,1);
dhpINT = (sinf(VINT, -70, 7) - hpINT) ./ p.tauhpbar;

% --------- Pack derivative ------------------------------------------------
dXdt = zeros(size(X));

% RA
dXdt(IDX.RA.V)    = dVRA;
dXdt(IDX.RA.n)    = dnRA;
dXdt(IDX.RA.h)    = dhRA;
dXdt(IDX.RA.e)    = deRA;
dXdt(IDX.RA.Cai)  = dCaiRA;
dXdt(IDX.RA.sAMin)= dsAMin;

% RA chain
dXdt(IDX.RA_CHAIN)= ds_chain;

% INT
dXdt(IDX.INT.V)   = dVINT;
dXdt(IDX.INT.n)   = dnINT;
dXdt(IDX.INT.h)   = dhINT;
dXdt(IDX.INT.rT)  = drTINT;
dXdt(IDX.INT.Cai) = dCaiINT;
dXdt(IDX.INT.rf)  = drfINT;
dXdt(IDX.INT.rs)  = drsINT;
dXdt(IDX.INT.hp)  = dhpINT;
end

% -------------------------------------------------------------------------
% Local helpers (local rather than nested, so they never share variables
% with the main function's workspace).

function g = ghk_factor(V, RTF)
% GHK_FACTOR  Evaluate V ./ (1 - exp(2*V./RTF)) with no 0/0 at V == 0.
%   With x = 2V/RTF:  V/(1 - e^x) = -(RTF/2) * x/expm1(x).
%   x/expm1(x) -> 1 as x -> 0, so that limit is substituted exactly;
%   expm1 also keeps full precision for small |x|.
x = 2 .* V ./ RTF;
ratio = x ./ expm1(x);
ratio(x == 0) = 1;
g = -(RTF ./ 2) .* ratio;
end

function validateParams(p)
% VALIDATEPARAMS  Error if p lacks a required field or has bad sizes.
%   Reports the first missing field (in the order of req below), then
%   checks that N_RA and N_INT are non-negative integer scalars.
req = {'N_RA','N_INT','VL','VK','VH','VNa','VAM','C_RA','C_INT', ...
       'gL','gCa','gKRA','gNaRA','gSKRA','gARA', ...
       'gNaINT','gKINT','gHINT','gNap', ...
       'Ca_ex','RTF','thetam','thetan','thetas','thetaa','thetae', ...
       'sigmam','sigman','sigmas','sigmaa','sigmae', ...
       'taunbar','taue','tauh','tauhpbar','ks','Tmax','VT','Kps', ...
       'gAMrr','arAM','adAM','arGA','adGA'};
present = isfield(p, req);
if ~all(present)
    error('Parameter struct missing field: %s', req{find(~present, 1)});
end
counts = {'N_RA', 'N_INT'};
for k = 1:numel(counts)
    v = p.(counts{k});
    if ~(isnumeric(v) && isscalar(v) && isfinite(v) && v >= 0 && v == fix(v))
        error('N1DE_refactored:badCount', ...
              '%s must be a non-negative integer scalar.', counts{k});
    end
end
end

function IDX = build_indices(N_RA, N_INT)
% BUILD_INDICES  Map each state variable to its linear positions in X.
%   Every field is a column vector with one index per cell (or per synapse
%   for RA_CHAIN). Per-cell states are contiguous: RA cell i occupies
%   6*(i-1)+(1:6), then come the N_RA-1 chain gates, then INT cells at
%   8 states each. IDX.n_states is the total required length of X.
RA_per  = 6;   % [V, n, h, e, Cai, sAMin]
INT_per = 8;   % [V, n, h, rT, Cai, rf, rs, hp]
n_RA_states = N_RA*RA_per;
n_chain     = max(N_RA-1, 0);

% reshape is column-major, so column i lists RA cell i's state indices.
RA_idx = reshape(1:n_RA_states, RA_per, N_RA);
IDX.RA.V     = RA_idx(1,:).';
IDX.RA.n     = RA_idx(2,:).';
IDX.RA.h     = RA_idx(3,:).';
IDX.RA.e     = RA_idx(4,:).';
IDX.RA.Cai   = RA_idx(5,:).';
IDX.RA.sAMin = RA_idx(6,:).';

% RA chain gates: N_RA-1 synapses (empty 0x1 when N_RA <= 1)
IDX.RA_CHAIN = (n_RA_states + (1:n_chain)).';

start = n_RA_states + n_chain;
INT_idx = reshape(start + (1:(N_INT*INT_per)), INT_per, N_INT);
IDX.INT.V   = INT_idx(1,:).';
IDX.INT.n   = INT_idx(2,:).';
IDX.INT.h   = INT_idx(3,:).';
IDX.INT.rT  = INT_idx(4,:).';
IDX.INT.Cai = INT_idx(5,:).';
IDX.INT.rf  = INT_idx(6,:).';
IDX.INT.rs  = INT_idx(7,:).';
IDX.INT.hp  = INT_idx(8,:).';

IDX.n_states = start + N_INT*INT_per;
end
